{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "import pickle\n",
    "from torch.utils import data as t_data\n",
    "import torchvision.datasets as datasets\n",
    "from torchvision import transforms\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_transforms = transforms.Compose([transforms.ToTensor()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 99%|█████████▉| 9846784/9912422 [00:18<00:00, 549474.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "0it [00:00, ?it/s]\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/28881 [00:00<?, ?it/s]\u001b[A\n",
      "32768it [00:00, 71965.37it/s]                            \u001b[A\n",
      "\n",
      "0it [00:00, ?it/s]\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/1648877 [00:00<?, ?it/s]\u001b[A\n",
      "  1%|          | 16384/1648877 [00:00<00:13, 125004.73it/s]\u001b[A\n",
      "  2%|▏         | 40960/1648877 [00:00<00:11, 141521.48it/s]\u001b[A\n",
      "  5%|▌         | 90112/1648877 [00:00<00:09, 169658.40it/s]\u001b[A\n",
      "  8%|▊         | 139264/1648877 [00:00<00:07, 210692.94it/s]\u001b[A\n",
      " 11%|█         | 180224/1648877 [00:00<00:06, 213049.37it/s]\u001b[A\n",
      " 14%|█▍        | 237568/1648877 [00:01<00:05, 257461.14it/s]\u001b[A\n",
      " 18%|█▊        | 303104/1648877 [00:01<00:04, 314691.01it/s]\u001b[A\n",
      " 22%|██▏       | 368640/1648877 [00:01<00:03, 364842.84it/s]\u001b[A\n",
      " 26%|██▋       | 434176/1648877 [00:01<00:02, 416195.96it/s]\u001b[A\n",
      " 30%|██▉       | 491520/1648877 [00:01<00:02, 450695.16it/s]\u001b[A\n",
      " 33%|███▎      | 548864/1648877 [00:01<00:02, 479280.74it/s]\u001b[A\n",
      " 37%|███▋      | 606208/1648877 [00:01<00:02, 462469.89it/s]\u001b[A\n",
      " 40%|████      | 663552/1648877 [00:01<00:02, 478041.50it/s]\u001b[A\n",
      " 45%|████▍     | 737280/1648877 [00:01<00:01, 528585.01it/s]\u001b[A\n",
      " 48%|████▊     | 794624/1648877 [00:02<00:01, 452416.83it/s]\u001b[A\n",
      " 52%|█████▏    | 851968/1648877 [00:02<00:02, 389677.34it/s]\u001b[A\n",
      " 57%|█████▋    | 942080/1648877 [00:02<00:01, 465880.52it/s]\u001b[A\n",
      " 61%|██████    | 999424/1648877 [00:02<00:01, 421989.87it/s]\u001b[A\n",
      " 64%|██████▍   | 1056768/1648877 [00:02<00:01, 425219.96it/s]\u001b[A\n",
      " 67%|██████▋   | 1105920/1648877 [00:02<00:01, 426051.91it/s]\u001b[A\n",
      " 70%|███████   | 1155072/1648877 [00:02<00:01, 396578.09it/s]\u001b[A\n",
      " 73%|███████▎  | 1204224/1648877 [00:03<00:01, 413193.04it/s]\u001b[A\n",
      " 76%|███████▌  | 1253376/1648877 [00:03<00:01, 388930.38it/s]\u001b[A\n",
      " 78%|███████▊  | 1294336/1648877 [00:03<00:00, 391800.21it/s]\u001b[A\n",
      " 82%|████████▏ | 1359872/1648877 [00:03<00:00, 440745.23it/s]\u001b[A\n",
      " 85%|████████▌ | 1409024/1648877 [00:03<00:00, 431335.87it/s]\u001b[A\n",
      " 88%|████████▊ | 1458176/1648877 [00:03<00:00, 396509.03it/s]\u001b[A\n",
      " 91%|█████████▏| 1507328/1648877 [00:03<00:00, 359261.40it/s]\u001b[A\n",
      " 95%|█████████▍| 1564672/1648877 [00:03<00:00, 401427.42it/s]\u001b[A\n",
      " 98%|█████████▊| 1613824/1648877 [00:04<00:00, 401247.03it/s]\u001b[A\n",
      "\n",
      "0it [00:00, ?it/s]\u001b[A\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "8192it [00:00, 40719.18it/s]            \u001b[A\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw\n",
      "Processing...\n",
      "Done!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "9920512it [00:29, 549474.96it/s]                             "
     ]
    }
   ],
   "source": [
    "\n",
    "mnist_trainset = datasets.MNIST(root='./data', \n",
    "                                train=True, \n",
    "                                download=True, \n",
    "                                transform=data_transforms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset MNIST\n",
       "    Number of datapoints: 60000\n",
       "    Root location: ./data\n",
       "    Split: Train\n",
       "    StandardTransform\n",
       "Transform: Compose(\n",
       "               ToTensor()\n",
       "           )"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "1654784it [00:20, 401247.03it/s]                             \u001b[A"
     ]
    }
   ],
   "source": [
    "mnist_trainset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size=4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader_mnist_train = t_data.DataLoader(mnist_trainset, \n",
    "                                           batch_size=batch_size,\n",
    "                                           shuffle=True\n",
    "                                          )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.utils.data.dataloader.DataLoader at 0x7f08d41703d0>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataloader_mnist_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_some_noise():\n",
    "    return torch.rand(batch_size,100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 100])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_some_noise().shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# defining generator class\n",
    "\n",
    "# N.B.: inherits from nn.Module\n",
    "\n",
    "class generator(nn.Module):\n",
    "    \n",
    "    def __init__(self, inp, out):\n",
    "        \n",
    "        super(generator, self).__init__()\n",
    "        \n",
    "        self.net = nn.Sequential(\n",
    "                                 nn.Linear(inp,300),\n",
    "                                 nn.ReLU(inplace=True),\n",
    "                                 nn.Linear(300,1000),\n",
    "                                 nn.ReLU(inplace=True),\n",
    "                                 nn.Linear(1000,800),\n",
    "                                 nn.ReLU(inplace=True),\n",
    "                                 nn.Linear(800,out)\n",
    "                                    )\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.net(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen = generator(100,784)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "generator(\n",
       "  (net): Sequential(\n",
       "    (0): Linear(in_features=100, out_features=300, bias=True)\n",
       "    (1): ReLU(inplace=True)\n",
       "    (2): Linear(in_features=300, out_features=1000, bias=True)\n",
       "    (3): ReLU(inplace=True)\n",
       "    (4): Linear(in_features=1000, out_features=800, bias=True)\n",
       "    (5): ReLU(inplace=True)\n",
       "    (6): Linear(in_features=800, out_features=784, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# defining discriminator class\n",
    "\n",
    "class discriminator(nn.Module):\n",
    "    \n",
    "    def __init__(self, inp, out):\n",
    "        \n",
    "        super(discriminator, self).__init__()\n",
    "        \n",
    "        self.net = nn.Sequential(\n",
    "                                 nn.Linear(inp,300),\n",
    "                                 nn.ReLU(inplace=True),\n",
    "                                 nn.Linear(300,300),\n",
    "                                 nn.ReLU(inplace=True),\n",
    "                                 nn.Linear(300,200),\n",
    "                                 nn.ReLU(inplace=True),\n",
    "                                 nn.Linear(200,out),\n",
    "                                 nn.Sigmoid()\n",
    "                                    )\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.net(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "dis = discriminator(784,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "discriminator(\n",
       "  (net): Sequential(\n",
       "    (0): Linear(in_features=784, out_features=300, bias=True)\n",
       "    (1): ReLU(inplace=True)\n",
       "    (2): Linear(in_features=300, out_features=300, bias=True)\n",
       "    (3): ReLU(inplace=True)\n",
       "    (4): Linear(in_features=300, out_features=200, bias=True)\n",
       "    (5): ReLU(inplace=True)\n",
       "    (6): Linear(in_features=200, out_features=1, bias=True)\n",
       "    (7): Sigmoid()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "d_steps = 1\n",
    "g_steps = 1\n",
    "\n",
    "criteriond1 = nn.BCELoss()\n",
    "optimizerd1 = optim.SGD(dis.parameters(), lr=0.001, momentum=0.9)\n",
    "\n",
    "criteriond2 = nn.BCELoss()\n",
    "optimizerd2 = optim.SGD(gen.parameters(), lr=0.001, momentum=0.9)\n",
    "\n",
    "printing_steps = 200\n",
    "\n",
    "epochs = 4000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_img(array,number=None):\n",
    "    array = array.detach()\n",
    "    array = array.reshape(28,28)\n",
    "    \n",
    "    plt.imshow(array,cmap='binary')\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "    if number:\n",
    "        plt.xlabel(number,fontsize='x-large')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(epochs):\n",
    "    \n",
    "    if epoch%printing_steps==0:\n",
    "        print(epoch)\n",
    "\n",
    "    # training discriminator\n",
    "    for d_step in range(d_steps):\n",
    "        dis.zero_grad()\n",
    "        \n",
    "        # training discriminator on real data\n",
    "        for inp_real,_ in dataloader_mnist_train:\n",
    "            inp_real_x = inp_real\n",
    "            break\n",
    "\n",
    "        inp_real_x = inp_real_x.reshape(4,784)\n",
    "        dis_real_out = dis(inp_real_x)\n",
    "        dis_real_loss = criteriond1(dis_real_out,Variable(torch.ones(batch_size,1)))\n",
    "        dis_real_loss.backward()\n",
    "\n",
    "        # training discriminator on data produced by generator\n",
    "        inp_fake_x_gen = make_some_noise()\n",
    "        dis_inp_fake_x = gen(inp_fake_x_gen).detach() #output from generator is generated\n",
    "        dis_fake_out = dis(dis_inp_fake_x)\n",
    "        dis_fake_loss = criteriond1(dis_fake_out,Variable(torch.zeros(batch_size,1)))\n",
    "        dis_fake_loss.backward()\n",
    "\n",
    "        optimizerd1.step()\n",
    "            \n",
    "    # training generator\n",
    "    for g_step in range(g_steps):\n",
    "        gen.zero_grad()\n",
    "        \n",
    "        #generating data for input for generator\n",
    "        gen_inp = make_some_noise()\n",
    "        \n",
    "        gen_out = gen(gen_inp)\n",
    "        dis_out_gen_training = dis(gen_out)\n",
    "        gen_loss = criteriond2(dis_out_gen_training,Variable(torch.ones(batch_size,1)))\n",
    "        gen_loss.backward()\n",
    "        \n",
    "        optimizerd2.step()\n",
    "        \n",
    "    if epoch%printing_steps==0:\n",
    "        plot_img(gen_out[0])\n",
    "        plot_img(gen_out[1])\n",
    "        plot_img(gen_out[2])\n",
    "        plot_img(gen_out[3])\n",
    "        print(\"\\n\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "... So now we need to turn this into deep speed code ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
